{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from sklearn import model_selection\n",
    "from sklearn import datasets\n",
    "from sklearn import metrics\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from tensorflow.contrib.learn.python.learn.estimators.estimator import SKCompat\n",
    "learn = tf.contrib.learn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 1. 自定义softmax回归模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def my_model(features, target):\n",
    "    target = tf.one_hot(target, 3, 1, 0)\n",
    "    \n",
    "    # 计算预测值及损失函数。\n",
    "    logits = tf.contrib.layers.fully_connected(features, 3, tf.nn.softmax)\n",
    "    loss = tf.losses.softmax_cross_entropy(target, logits)\n",
    "    \n",
    "    # 创建优化步骤。\n",
    "    train_op = tf.contrib.layers.optimize_loss(\n",
    "        loss,\n",
    "        tf.contrib.framework.get_global_step(),\n",
    "        optimizer='Adam',\n",
    "        learning_rate=0.01)\n",
    "    return tf.arg_max(logits, 1), loss, train_op"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2. 读取数据并将数据转化成TensorFlow要求的float32格式。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "iris = datasets.load_iris()\n",
    "x_train, x_test, y_train, y_test = model_selection.train_test_split(\n",
    "    iris.data, iris.target, test_size=0.2, random_state=0)\n",
    "\n",
    "x_train, x_test = map(np.float32, [x_train, x_test])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3. 封装和训练模型，输出准确率。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using default config.\n",
      "INFO:tensorflow:Using config: {'_save_checkpoints_secs': 600, '_num_ps_replicas': 0, '_keep_checkpoint_max': 5, '_tf_random_seed': None, '_task_type': None, '_environment': 'local', '_is_chief': True, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x11487e290>, '_tf_config': gpu_options {\n",
      "  per_process_gpu_memory_fraction: 1\n",
      "}\n",
      ", '_task_id': 0, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_evaluation_master': '', '_keep_checkpoint_every_n_hours': 10000, '_master': ''}\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Saving checkpoints for 1 into Models/model_1/model.ckpt.\n",
      "INFO:tensorflow:loss = 1.09918, step = 1\n",
      "INFO:tensorflow:global_step/sec: 677.497\n",
      "INFO:tensorflow:loss = 0.783171, step = 101\n",
      "INFO:tensorflow:global_step/sec: 706.629\n",
      "INFO:tensorflow:loss = 0.696257, step = 201\n",
      "INFO:tensorflow:global_step/sec: 701.641\n",
      "INFO:tensorflow:loss = 0.655656, step = 301\n",
      "INFO:tensorflow:global_step/sec: 562.503\n",
      "INFO:tensorflow:loss = 0.634671, step = 401\n",
      "INFO:tensorflow:global_step/sec: 582.211\n",
      "INFO:tensorflow:loss = 0.622115, step = 501\n",
      "INFO:tensorflow:global_step/sec: 522.269\n",
      "INFO:tensorflow:loss = 0.613794, step = 601\n",
      "INFO:tensorflow:global_step/sec: 583.315\n",
      "INFO:tensorflow:loss = 0.607876, step = 701\n",
      "INFO:tensorflow:Saving checkpoints for 800 into Models/model_1/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 0.603483.\n",
      "Accuracy: 100.00%\n"
     ]
    }
   ],
   "source": [
    "classifier = SKCompat(learn.Estimator(model_fn=my_model, model_dir=\"Models/model_1\"))\n",
    "classifier.fit(x_train, y_train, steps=800)\n",
    "\n",
    "y_predicted = [i for i in classifier.predict(x_test)]\n",
    "score = metrics.accuracy_score(y_test, y_predicted)\n",
    "print('Accuracy: %.2f%%' % (score * 100))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
